"""
Phase 4: Groneck-Kaufmann Replication & State-Space Analysis
==============================================================
Tests:
  (a) Replicate Groneck-Kaufmann on their approximate sample (15 OECD, 1970-2009)
  (b) Expand to all OECD — does the result hold?
  (c) State-space analysis: is the G-K result a function of income, openness,
      demographic stage, or service-sector share?
  (d) Robustness battery for the global refutation

The goal is to be bulletproof if we claim to partially refute G-K.

Output: output/tables/groneck_replication.md, state_space.md, robustness.md
"""

import sys
from pathlib import Path

import numpy as np
import pandas as pd
from scipy import stats

# Hard-coded absolute project root. NOTE(review): the /mnt/c prefix looks like
# a WSL mount of a Windows drive — adjust when running on another machine.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows/rer")
ROOT_DIR = PROJECT_DIR.parent
# Put the sibling "multilateral/src" directory first on sys.path so that the
# `model` module (which provides PanelGLS) resolves on the import below.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Where the input panel CSV is read from, and where markdown tables are written.
PROCESSED_DIR = PROJECT_DIR / "data" / "processed"
TABLES_DIR = PROJECT_DIR / "output" / "tables"

# Approximate Groneck-Kaufmann sample (15 OECD countries, 1970-2009).
# Their exact country list is not clearly published, so this uses the
# 15 core OECD members of the 1970s (ISO3 codes).
GK_COUNTRIES = (
    "AUS AUT BEL CAN DNK "
    "FIN FRA DEU ITA JPN "
    "NLD NOR SWE GBR USA"
).split()

# ISO3 codes of the 38 OECD member countries.
OECD_38 = (
    "AUS AUT BEL CAN CHL COL CRI CZE DNK EST "
    "FIN FRA DEU GRC HUN ISL IRL ISR ITA JPN "
    "KOR LVA LTU LUX MEX NLD NZL NOR POL PRT "
    "SVK SVN ESP SWE CHE TUR GBR USA"
).split()


def stars(p):
    """Return the significance marker for p-value ``p``.

    ``'***'`` below 0.01, ``'**'`` below 0.05, ``'*'`` below 0.10, and an
    empty string otherwise (including when ``p`` is NaN, since every
    comparison with NaN is false).
    """
    thresholds = ((0.01, '***'), (0.05, '**'), (0.10, '*'))
    for cutoff, marker in thresholds:
        if p < cutoff:
            return marker
    return ''


def run_model(df, dep_var, regressors, label, silent=False):
    """Fit ``PanelGLS`` of ``dep_var`` on ``regressors`` and summarise the fit.

    Regressors that are not columns of ``df`` are dropped before estimation;
    unless ``silent`` is set, the dropped names are reported so that a label
    such as "Z + BS" cannot silently describe a model without its controls.

    Args:
        df: Panel data; must contain ``iso3`` and ``year`` columns, which are
            passed to the estimator alongside the dependent variable and the
            regressor matrix.
        dep_var: Name of the dependent-variable column.
        regressors: Candidate regressor column names.
        label: Human-readable model label used in console output and tables.
        silent: When true, suppress all console output.

    Returns:
        A dict with ``label``, ``dep_var``, ``n_obs``, ``n_countries``,
        ``r_squared`` and, for every regressor actually used,
        ``coef_<name>``, ``se_<name>`` and ``p_<name>``. Returns ``None`` when
        the model is skipped: the dependent variable is missing, no requested
        regressor exists in ``df``, or fewer than 30 complete rows remain.
    """
    requested = list(regressors)
    regressors = [r for r in requested if r in df.columns]
    if dep_var not in df.columns:
        return None

    dropped = [r for r in requested if r not in df.columns]
    if dropped and not silent:
        print(f"  [{label}] Note: regressors not in data, dropped: "
              f"{', '.join(dropped)}")
    if not regressors:
        # Nothing to estimate; avoid handing a zero-column design matrix
        # to the estimator.
        if not silent:
            print(f"  [{label}] No usable regressors — skipping")
        return None

    # Keep only rows that are complete for this specification.
    sub = df.dropna(subset=[dep_var] + regressors).copy()
    if len(sub) < 30:
        if not silent:
            print(f"  [{label}] Insufficient obs ({len(sub)}) — skipping")
        return None

    gls = PanelGLS()
    gls.fit(sub[dep_var].values, sub[regressors].values,
            sub['iso3'].values, sub['year'].values)
    if not silent:
        print(f"\n  [{label}]  N={gls.n_obs}, countries={gls.n_countries}, R²={gls.r_squared:.4f}")
        for i, name in enumerate(regressors):
            sig = stars(gls.pvalues[i])
            print(f"    {name:<30} {gls.beta[i]:>10.4f} ({gls.se[i]:.4f}) {sig}")

    results = {
        'label': label, 'dep_var': dep_var,
        'n_obs': gls.n_obs, 'n_countries': gls.n_countries,
        'r_squared': gls.r_squared,
    }
    # Estimator outputs are indexed in the same order as ``regressors``.
    for i, name in enumerate(regressors):
        results[f'coef_{name}'] = gls.beta[i]
        results[f'se_{name}'] = gls.se[i]
        results[f'p_{name}'] = gls.pvalues[i]
    return results


def build_table(results, key_vars, notes, filename, title):
    """Render model summaries as a markdown report under ``TABLES_DIR``.

    The report has a model-overview table (label, dependent variable, N,
    countries, R²) followed by a "Key Coefficients" table listing, for each
    model, every variable in ``key_vars`` for which the model has a
    ``coef_<var>`` entry.

    Args:
        results: List of result dicts as returned by ``run_model``. Nothing
            is written when this is empty or ``None``.
        key_vars: Variable names to include in the coefficient table.
        notes: Free-text footnote, emitted in italics at the end.
        filename: Output file name, relative to ``TABLES_DIR``.
        title: Top-level markdown heading.

    Returns:
        None.
    """
    if not results:
        return

    md = [f"# {title}\n"]
    md.append("| Model | Dep Var | N | Countries | R² |")
    md.append("|---|---|---|---|---|")
    for r in results:
        md.append(f"| {r['label']} | {r['dep_var']} | {r['n_obs']:,} "
                  f"| {r['n_countries']} | {r['r_squared']:.3f} |")

    md.append("\n## Key Coefficients\n")
    md.append("| Model | Variable | Coef | SE | p-value | Sig |")
    md.append("|---|---|---|---|---|---|")
    for r in results:
        for var in key_vars:
            ckey = f'coef_{var}'
            if ckey in r:
                p = r[f'p_{var}']
                md.append(f"| {r['label']} | {var} | {r[ckey]:.4f} "
                          f"| {r[f'se_{var}']:.4f} | {p:.4f} | {stars(p)} |")
    md.append(f"\n*{notes}*")

    out = TABLES_DIR / filename
    # Create the output directory on first use instead of failing.
    out.parent.mkdir(parents=True, exist_ok=True)
    # The tables contain non-ASCII text (e.g. "R²", "Z₁", "≤"); write UTF-8
    # explicitly rather than relying on the platform's default encoding.
    out.write_text('\n'.join(md), encoding='utf-8')
    print(f"\n  Saved: {out}")


def _tercile_labels(values, labels, min_obs=10):
    """Assign three quantile-group labels to one cross-section of values.

    Returns an all-NaN series when fewer than ``min_obs`` non-missing values
    are available. Otherwise uses ``pd.qcut``; if ties collapse a quantile
    edge (``qcut`` then raises ``ValueError`` because the three labels no
    longer match the remaining bins), falls back to cutting the average
    percentile rank at 1/3 and 2/3, which keeps tied values in one group.
    """
    if len(values.dropna()) < min_obs:
        return pd.Series(np.nan, index=values.index)
    try:
        return pd.qcut(values, 3, labels=labels, duplicates='drop')
    except ValueError:
        pct_rank = values.rank(pct=True, method='average')
        return pd.cut(pct_rank, bins=[0.0, 1 / 3, 2 / 3, 1.0],
                      labels=labels, include_lowest=True)


def _year_terciles(df, column, labels):
    """Label ``df[column]`` by tercile within each year's cross-section."""
    return df.groupby('year')[column].transform(
        lambda x: _tercile_labels(x, labels))


def _fit_into(bucket, data, dep_var, regressors, label):
    """Run one model and append its summary to ``bucket`` if it was estimated."""
    fit = run_model(data, dep_var, regressors, label)
    if fit:
        bucket.append(fit)


def _fit_by_group(bucket, df, column, groups, dep_var, regressors, label_fmt):
    """Fit the same specification on each ``df[column] == group`` subsample.

    ``label_fmt`` is formatted with ``initial`` (first letter of the group
    name) and ``group``.
    """
    for group in groups:
        subset = df[df[column] == group].copy()
        _fit_into(bucket, subset, dep_var, regressors,
                  label_fmt.format(initial=group[0], group=group))


def main():
    """Run the Phase 4 analysis end to end.

    Reads ``rer_panel.csv`` from ``PROCESSED_DIR``, then:
      * Part A fits the demographic REER models on the GK15, OECD,
        non-OECD and global samples (``groneck_replication.md``);
      * Part B re-fits the main specification on subsamples split by income
        group, capital/trade openness, demographic stage, NFA position and
        health expenditure (``state_space.md``);
      * Part C runs the robustness battery (``robustness.md``);
      * finally prints the sign and significance of the Z_1 coefficient for
        every estimated model.

    Sections whose splitting column is absent from the panel are skipped.
    """
    banner = "=" * 70
    print(banner)
    print("PHASE 4: Groneck-Kaufmann Replication & State-Space Analysis")
    print(banner)

    df = pd.read_csv(PROCESSED_DIR / "rer_panel.csv")
    print(f"Panel: {df['iso3'].nunique()} countries, {len(df):,} obs")

    reer = 'log_reer_combined'
    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    age_vars = ['old_dep', 'youth_dep']
    controls_bs = ['rgdp_growth', 'fiscal_bal_gdp', 'kaopen', 'nfa_gdp_lag', 'log_gdp_pc']
    controls_bs = [c for c in controls_bs if c in df.columns]
    controls_basic = ['rgdp_growth', 'fiscal_bal_gdp', 'kaopen', 'nfa_gdp_lag']
    controls_basic = [c for c in controls_basic if c in df.columns]
    # Main specification used throughout: demographic PCs plus controls
    # including the Balassa-Samuelson proxy (log GDP per capita).
    z_bs = demo_vars + controls_bs

    # ═══════════════════════════════════════════════════════════════════
    # PART A: GRONECK-KAUFMANN REPLICATION
    # ═══════════════════════════════════════════════════════════════════
    print("\n" + banner)
    print("PART A: GRONECK-KAUFMANN REPLICATION (15 OECD, up to 2009)")
    print(banner)

    gk_results = []
    in_gk = df['iso3'].isin(GK_COUNTRIES)
    in_oecd = df['iso3'].isin(OECD_38)
    up_to_2009 = df['year'] <= 2009

    # GK sample: 15 countries, up to 2009
    gk = df[in_gk & up_to_2009].copy()
    print(f"  GK sample: {gk['iso3'].nunique()} countries, "
          f"{gk['reer_combined'].notna().sum()} REER obs")

    # A1: Z → log REER (no BS, GK sample)
    _fit_into(gk_results, gk, reer, demo_vars + controls_basic,
              "A1: GK15 Z (no BS)")

    # A2: Z → log REER (with BS, GK sample)
    _fit_into(gk_results, gk, reer, z_bs, "A2: GK15 Z + BS")

    # A3: age ratios → log REER (GK sample)
    _fit_into(gk_results, gk, reer, age_vars + controls_bs,
              "A3: GK15 age ratios + BS")

    # A4: Expand to all OECD, up to 2009
    oecd_09 = df[in_oecd & up_to_2009].copy()
    _fit_into(gk_results, oecd_09, reer, z_bs, "A4: OECD38 Z + BS (≤2009)")

    # A5: GK15 countries but full time period
    gk_full = df[in_gk].copy()
    _fit_into(gk_results, gk_full, reer, z_bs, "A5: GK15 Z + BS (full period)")

    # A6: All OECD, full period
    oecd_full = df[in_oecd].copy()
    _fit_into(gk_results, oecd_full, reer, z_bs, "A6: OECD38 Z + BS (full)")

    # A7: Non-OECD only
    non_oecd = df[~in_oecd].copy()
    _fit_into(gk_results, non_oecd, reer, z_bs, "A7: non-OECD Z + BS")

    # A8: Full global sample
    _fit_into(gk_results, df, reer, z_bs, "A8: Global Z + BS")

    key_gk_vars = ['Z_1', 'Z_2', 'Z_3', 'old_dep', 'youth_dep', 'log_gdp_pc']
    build_table(gk_results, key_gk_vars,
                "GK15 = AUS/AUT/BEL/CAN/DNK/FIN/FRA/DEU/ITA/JPN/NLD/NOR/SWE/GBR/USA. "
                "BS control = log(GDP per capita PPP).",
                "groneck_replication.md",
                "Groneck-Kaufmann Replication & Sample Expansion")

    # ═══════════════════════════════════════════════════════════════════
    # PART B: STATE-SPACE ANALYSIS
    # Why does the sign flip? Test each dimension.
    # ═══════════════════════════════════════════════════════════════════
    print("\n" + banner)
    print("PART B: STATE-SPACE ANALYSIS — Why does the sign flip?")
    print(banner)

    ss_results = []

    # B1: Income groups (pre-computed 'income_group' column)
    if 'income_group' in df.columns:
        _fit_by_group(ss_results, df, 'income_group', ['low', 'mid', 'high'],
                      reer, z_bs, "B1{initial}: {group}-income")

    # B2: KAOPEN terciles (capital openness)
    if 'kaopen' in df.columns:
        ka_groups = ['closed', 'mid', 'open']
        df['ka_tercile'] = _year_terciles(df, 'kaopen', ka_groups)
        _fit_by_group(ss_results, df, 'ka_tercile', ka_groups,
                      reer, z_bs, "B2{initial}: {group} capital")

    # B3: Trade openness terciles
    if 'trade_openness' in df.columns:
        trade_groups = ['low', 'mid', 'high']
        df['trade_tercile'] = _year_terciles(df, 'trade_openness', trade_groups)
        _fit_by_group(ss_results, df, 'trade_tercile', trade_groups,
                      reer, z_bs, "B3{initial}: {group} trade")

    # B4: Demographic stage (Z₁ terciles = young, transitioning, old)
    if 'Z_1' in df.columns:
        stage_groups = ['young', 'trans', 'old']
        df['demo_stage'] = _year_terciles(df, 'Z_1', stage_groups)
        _fit_by_group(ss_results, df, 'demo_stage', stage_groups,
                      reer, z_bs, "B4{initial}: {group} demo stage")

    # B5: NFA position (creditor vs debtor — from Phase 3 we know this matters)
    if 'nfa_gdp_lag' in df.columns:
        creditor = df[df['nfa_gdp_lag'] >= 0].copy()
        debtor = df[df['nfa_gdp_lag'] < 0].copy()
        _fit_into(ss_results, creditor, reer, z_bs, "B5a: NFA creditor")
        _fit_into(ss_results, debtor, reer, z_bs, "B5b: NFA debtor")

    # B6: Health expenditure terciles (non-tradable sector size)
    if 'health_exp_gdp' in df.columns:
        health_groups = ['low', 'mid', 'high']
        df['health_tercile'] = _year_terciles(df, 'health_exp_gdp', health_groups)
        _fit_by_group(ss_results, df, 'health_tercile', health_groups,
                      reer, z_bs, "B6{initial}: {group} health_exp")

    key_ss_vars = ['Z_1', 'Z_2', 'Z_3', 'log_gdp_pc']
    build_table(ss_results, key_ss_vars,
                "State-space decomposition: which country characteristics drive the sign of Z₁?",
                "state_space.md",
                "State-Space Analysis: Where Does Z₁ Predict REER?")

    # ═══════════════════════════════════════════════════════════════════
    # PART C: ROBUSTNESS BATTERY
    # ═══════════════════════════════════════════════════════════════════
    print("\n" + banner)
    print("PART C: ROBUSTNESS BATTERY")
    print(banner)

    rob_results = []

    # C1: Alternative REER measures
    for dep_var, label in [('log_reer_combined', 'C1a: level'),
                           ('d_log_reer', 'C1b: first-diff'),
                           ('log_reer_ma5', 'C1c: MA5'),
                           ('log_reer_demeaned', 'C1d: demeaned')]:
        _fit_into(rob_results, df, dep_var, z_bs, label)

    # C2: BIS-only REER (exclude IMF supplement)
    if 'reer' in df.columns and 'log_reer' in df.columns:
        _fit_into(rob_results, df, 'log_reer', z_bs, "C2: BIS-only REER")

    # C3: Drop outlier countries (top/bottom 5% of country-mean log REER)
    reer_mean = df.groupby('iso3')[reer].mean()
    p5, p95 = reer_mean.quantile(0.05), reer_mean.quantile(0.95)
    non_outlier = reer_mean[(reer_mean >= p5) & (reer_mean <= p95)].index
    df_no_outlier = df[df['iso3'].isin(non_outlier)].copy()
    _fit_into(rob_results, df_no_outlier, reer, z_bs, "C3: drop REER outliers")

    # C4: Add inflation control (Balassa-Samuelson related)
    if 'cpi_inflation_wb' in df.columns:
        _fit_into(rob_results, df, reer, z_bs + ['cpi_inflation_wb'],
                  "C4: + inflation control")

    # C5: Add gross_investment_gdp control
    if 'gross_investment_gdp' in df.columns:
        _fit_into(rob_results, df, reer, z_bs + ['gross_investment_gdp'],
                  "C5: + investment/GDP")

    # C6: Drop small islands (population < 1M)
    # NOTE(review): assumes population_weo is expressed in millions — confirm.
    if 'population_weo' in df.columns:
        large = df[df['population_weo'] > 1].copy()
        _fit_into(rob_results, large, reer, z_bs, "C6: pop > 1M only")

    # C7: Exclude financial centers (HKG, SGP, LUX, IRL, CHE)
    fc = ['HKG', 'SGP', 'LUX', 'IRL', 'CHE']
    df_no_fc = df[~df['iso3'].isin(fc)].copy()
    _fit_into(rob_results, df_no_fc, reer, z_bs, "C7: excl fin centers")

    # C8: Only countries with >= 10 years of REER data
    years_count = df.groupby('iso3')[reer].apply(lambda x: x.notna().sum())
    long_panel = years_count[years_count >= 10].index
    df_long = df[df['iso3'].isin(long_panel)].copy()
    _fit_into(rob_results, df_long, reer, z_bs, "C8: ≥10yr REER data")

    # C9: Winsorize REER at 1%/99%
    df_w = df.copy()
    p1 = df_w[reer].quantile(0.01)
    p99 = df_w[reer].quantile(0.99)
    df_w['log_reer_winsorized'] = df_w[reer].clip(p1, p99)
    _fit_into(rob_results, df_w, 'log_reer_winsorized', z_bs,
              "C9: winsorized REER")

    # C10: Use OADR directly instead of Z PCs
    _fit_into(rob_results, df, reer, ['old_dep'] + controls_bs,
              "C10: OADR only")

    # C11: 5-year averages (reduce serial correlation)
    # Periods are 5-year bins anchored at 1990; 'year' is averaged along with
    # the other numeric columns and serves as the period's time index.
    df['period'] = (df['year'] - 1990) // 5
    df_5yr = df.groupby(['iso3', 'period']).mean(numeric_only=True).reset_index()
    df_5yr['iso3'] = df_5yr['iso3'].astype(str)
    _fit_into(rob_results, df_5yr, reer, z_bs, "C11: 5yr averages")

    key_rob_vars = ['Z_1', 'Z_2', 'Z_3', 'old_dep', 'log_gdp_pc']
    build_table(rob_results, key_rob_vars,
                "Robustness battery for the global demographic REER effect",
                "robustness.md",
                "Robustness Battery: Demographics → REER")

    # ═══════════════════════════════════════════════════════════════════
    # SUMMARY: Sign map
    # ═══════════════════════════════════════════════════════════════════
    print("\n" + banner)
    print("SIGN MAP SUMMARY")
    print(banner)

    all_results = gk_results + ss_results + rob_results
    print(f"\n  {'Model':<40} {'Z₁ coef':>10} {'p':>8} {'Sign':>6}")
    print("  " + "-" * 70)
    for r in all_results:
        z1 = r.get('coef_Z_1', None)
        p = r.get('p_Z_1', None)
        # Models without Z_1 (e.g. the age-ratio and OADR-only runs) are omitted.
        if z1 is not None and p is not None:
            sign = "+" if z1 > 0 else "-"
            sig = stars(p)
            print(f"  {r['label']:<40} {z1:>10.4f} {p:>8.4f} {sign:>4}{sig}")

    print("\n" + banner)
    print("Phase 4 complete.")
    print(banner)

if __name__ == "__main__":
    main()
